import io
import logging
from functools import partial
from pathlib import PurePath
from typing import Optional, Protocol

import paramiko

from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.credentials import Credentials, Password, SSHKeypair, Username, get_plaintext
from common.types import NetworkPort
from infection_monkey.i_puppet import TargetHost

logger = logging.getLogger(__name__)

# Timeouts forwarded to paramiko (which interprets them as seconds).
# Waiting for the server banner and for channel opening uses the medium timeout...
SSH_BANNER_TIMEOUT = MEDIUM_REQUEST_TIMEOUT
SSH_CHANNEL_TIMEOUT = MEDIUM_REQUEST_TIMEOUT
# ...while authentication and remote command execution get the long timeout.
SSH_AUTH_TIMEOUT = LONG_REQUEST_TIMEOUT
SSH_EXEC_TIMEOUT = LONG_REQUEST_TIMEOUT


class SSHConnectFunction(Protocol):
    """
    Callable that authenticates a paramiko client against a target host

    Implementations receive every argument by keyword and return nothing; failures are
    reported by raising.
    """

    def __call__(
        self,
        *,
        client: paramiko.SSHClient,
        host: TargetHost,
        port: NetworkPort,
        timeout: float,
    ) -> None:
        """Connect `client` to `host` on `port`, using `timeout` as the connection timeout."""


class SSHClient:
    """
    Wrapper around a paramiko SSH client used to authenticate to a remote host, upload files
    over SFTP, and execute commands.
    """

    def __init__(self):
        # The paramiko client; remains None until connect() succeeds
        self._client: Optional[paramiko.SSHClient] = None
        self._authenticated = False
        # Fraction of the current SFTP upload above which the next progress message is logged
        self._percent_transferred_log_target = 0.1

    def connect(
        self, host: TargetHost, credentials: Credentials, port: NetworkPort, timeout: float
    ):
        """
        Connect to the host using SSH

        Credentials may be a username and password, or a username and private key.

        :param host: Host to connect to
        :param credentials: Credentials to use for the connection
        :param port: Port to connect to
        :param timeout: Timeout for the connection, in seconds
        :raises ValueError: If the credentials' identity or secret type is not supported
        :raises Exception: If the connection could not be established
        """

        connect_function = self._get_ssh_connection_function(credentials)
        client = paramiko.SSHClient()
        # Unknown host keys are accepted with a warning instead of being rejected
        client.set_missing_host_key_policy(paramiko.WarningPolicy())

        try:
            connect_function(client=client, host=host, port=port, timeout=timeout)
        except Exception:
            # Release the resources of the failed attempt before propagating the error
            client.close()
            raise

        logger.debug(f"Successfully authenticated using SSH on host: {host.ip}")
        self._client = client
        self._authenticated = True

    def _get_ssh_connection_function(self, credentials: Credentials) -> SSHConnectFunction:
        """
        Select the connection strategy that matches the credentials' secret type

        :param credentials: Credentials whose identity must be a Username and whose secret must
                            be an SSHKeypair or a Password
        :return: A callable that authenticates a paramiko client using the credentials
        :raises ValueError: If the identity or secret type is not supported
        """
        if not isinstance(credentials.identity, Username):
            message = "Unrecognised credential identity type"
            logger.debug(message)
            raise ValueError(message)

        username = credentials.identity.username
        secret = credentials.secret

        if isinstance(secret, SSHKeypair):
            return partial(
                self._connect_with_private_key,
                username=username,
                private_key=secret.private_key,
            )

        if isinstance(secret, Password):
            return partial(
                self._connect_with_password,
                username=username,
                password=secret.password,
            )

        message = "Unrecognised credential secret type"
        logger.debug(message)
        raise ValueError(message)

    def _connect_with_private_key(
        self,
        client: paramiko.SSHClient,
        host: TargetHost,
        username: str,
        private_key: str,
        port: NetworkPort,
        timeout: float,
    ):
        """
        Authenticate `client` to `host` as `username` using an RSA private key

        :raises IOError, paramiko.SSHException: If the private key cannot be read/parsed
        :raises paramiko.AuthenticationException: If the host rejects the key
        :raises Exception: If the connection fails for any other reason
        """
        try:
            private_key_buffer = io.StringIO(get_plaintext(private_key))
            # NOTE: only RSA keys are parsed here; other key types fail to load
            private_key_object = paramiko.RSAKey.from_private_key(private_key_buffer)
        except (IOError, paramiko.SSHException, paramiko.PasswordRequiredException):
            logger.error("Failed reading SSH key")
            raise

        try:
            client.connect(
                str(host.ip),
                username=username,
                pkey=private_key_object,
                port=int(port),
                timeout=timeout,
                auth_timeout=SSH_AUTH_TIMEOUT,
                banner_timeout=SSH_BANNER_TIMEOUT,
                channel_timeout=SSH_CHANNEL_TIMEOUT,
                allow_agent=False,
            )
            logger.debug(
                f"Successfully logged into {host.ip} using {username}@{host.ip} user's private key"
            )
        except paramiko.AuthenticationException as err:
            error_message = (
                f"Failed logging into victim {host.ip} with {username}@{host.ip} user's "
                f"private key: {err}"
            )
            logger.info(error_message)
            raise
        except Exception as err:
            error_message = (
                f"Unexpected error while attempting to login to {username}@{host.ip} with SSH key: "
                f"{err}"
            )
            logger.error(error_message)
            raise

    def _connect_with_password(
        self,
        client: paramiko.SSHClient,
        host: TargetHost,
        username: str,
        password: str,
        port: NetworkPort,
        timeout: float,
    ):
        """
        Authenticate `client` to `host` as `username` using a password

        :raises paramiko.AuthenticationException: If the host rejects the credentials
        :raises Exception: If the connection fails for any other reason
        """
        try:
            client.connect(
                str(host.ip),
                username=username,
                password=get_plaintext(password),
                port=int(port),
                timeout=timeout,
                auth_timeout=SSH_AUTH_TIMEOUT,
                banner_timeout=SSH_BANNER_TIMEOUT,
                channel_timeout=SSH_CHANNEL_TIMEOUT,
                allow_agent=False,
            )
            logger.debug(f"Successfully logged in {host.ip}, User: {username}")
        except paramiko.AuthenticationException as err:
            error_message = f"Failed logging into victim {host.ip} with user: {username}: {err}"
            # Log authentication failures the same way the private-key path does
            logger.info(error_message)
            raise
        except Exception as err:
            error_message = f"Unexpected error while attempting to login to {host.ip} with password: {err}"
            logger.debug(error_message)
            raise

    def _get_connected_client(self) -> paramiko.SSHClient:
        """
        Return the underlying paramiko client

        :raises RuntimeError: If connect() has not completed successfully
        """
        if self._client is None:
            raise RuntimeError("SSH client is not connected")
        return self._client

    def copy_file(
        self,
        file: bytes,
        destination_path: PurePath,
    ):
        """
        Copy a file to the remote host using SFTP

        :param file: File to copy
        :param destination_path: File destination path on the remote host
        :raises Exception: If the file copy failed (including when the client is not connected)
        """
        # Progress logging restarts for every upload
        self._percent_transferred_log_target = 0.1

        try:
            client = self._get_connected_client()
            with client.open_sftp() as sftp:
                sftp.putfo(
                    io.BytesIO(file),
                    str(destination_path),
                    file_size=len(file),
                    callback=self._log_transfer,
                )
                # Owner-only read/write/execute permissions on the uploaded file
                sftp.chmod(str(destination_path), 0o700)
        except Exception as err:
            error_message = f"Error uploading file: ({err})"
            logger.error(error_message)
            raise

    def _log_transfer(self, transferred: int, total: int):
        """
        SFTP progress callback: log a debug message each time another ~10% has been uploaded

        :param transferred: Number of bytes transferred so far
        :param total: Total number of bytes to transfer
        """
        if total <= 0:
            # Nothing meaningful to report (and avoid dividing by zero)
            return

        fraction_transferred = transferred / total
        if fraction_transferred > self._percent_transferred_log_target:
            logger.debug(f"SFTP transferred: {transferred} bytes, total: {total} bytes")
            # A single large chunk can cross several 10% marks; move the target past all of
            # them so the following callbacks are not each logged while the target catches up
            while fraction_transferred > self._percent_transferred_log_target:
                self._percent_transferred_log_target += 0.1

    def execute_command(self, command: str) -> bytes:
        """
        Execute a command on the remote host

        Blocks until the command's standard output reaches EOF or the channel timeout
        (SSH_EXEC_TIMEOUT) expires.

        :param command: Command to execute
        :raises Exception: If the command execution failed (including when the client is not
                           connected or the channel times out)
        :return: The command's standard output
        """
        client = self._get_connected_client()
        _, stdout, _ = client.exec_command(command=command, timeout=SSH_EXEC_TIMEOUT)
        # exec_command() returns file-like objects; read stdout to get the output bytes
        return stdout.read()

    def connected(self) -> bool:
        """Return True once connect() has succeeded."""
        return self._authenticated
